{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Feature Extraction\n",
    "===============\n",
    "\n",
    "This code will use Tensorflow/Keras to extract VGG16-FC2 features from the image datasets and save them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tensorflow.keras.applications.vgg16 import (VGG16, preprocess_input)\n",
    "from tensorflow.keras.models import Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "vgg16 = VGG16()\n",
    "\n",
    "extractor = Model(\n",
    "  inputs = vgg16.input,\n",
    "  outputs = vgg16.get_layer(\"fc2\").output\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def preprocess(fp):\n",
    "    img = Image.open(fp)\n",
    "    img = img.resize((224, 224))\n",
    "    img = img.convert('RGB')\n",
    "    img = np.array(img)\n",
    "    img = np.expand_dims(img, 0)\n",
    "    img = preprocess_input(img)\n",
    "    return img"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tarfile\n",
    "from PIL import Image\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import sys, os, glob\n",
    "import tempfile, shutil\n",
    "from tqdm.notebook import tqdm\n",
    "from multiprocessing import Pool"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract Target Set\n",
    "This code will process the image archives from the target set. If you have not done so already, open and run the cells in `01_download.ipynb` that download the target set.\n",
    "\n",
    "**Note**: There are 16 archives in total; processing each one may take up to an hour."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# this parameter controls the number of images loaded into memory at once.\n",
    "# lower it if you are having problems.\n",
    "split_size = 5000\n",
    "\n",
    "archive_paths = sorted(glob.glob('/tf/open_images/targets/images/train_*.tar.gz'))\n",
    "pool = Pool()\n",
    "\n",
    "try:\n",
    "    if not os.path.exists('/tf/open_images/targets/features/'):\n",
    "        os.mkdir('/tf/open_images/targets/features/')\n",
    "\n",
    "    for archive_path in tqdm(archive_paths, desc='all files'):\n",
    "        ar_name = os.path.basename(archive_path).replace('.tar.gz', '')\n",
    "        feat_path = '/tf/open_images/targets/features/%s.h5' % ar_name.replace('train', 'targets')\n",
    "\n",
    "        if os.path.exists(feat_path):\n",
    "            continue\n",
    "\n",
    "        feature_store = pd.HDFStore(feat_path, complevel=9, mode='w')\n",
    "\n",
    "        tempdir = tempfile.mkdtemp()\n",
    "        archive = tarfile.open(archive_path)\n",
    "        archive.extractall(tempdir)\n",
    "\n",
    "        image_paths = sorted(glob.glob(os.path.join(tempdir, ar_name, \"*.jpg\")))\n",
    "        splits = np.array_split(np.array(image_paths), len(image_paths)/split_size)\n",
    "\n",
    "        for split in tqdm(splits, desc=ar_name):\n",
    "            image_ids = [ os.path.basename(path).replace('.jpg', '') for path in split ]\n",
    "\n",
    "            image_array = np.concatenate(pool.map(preprocess, split))\n",
    "\n",
    "            features = extractor.predict(image_array)\n",
    "            del image_array\n",
    "\n",
    "            frame = pd.DataFrame(features, index=image_ids)\n",
    "            del features\n",
    "\n",
    "            feature_store.append('df', frame)\n",
    "\n",
    "        shutil.rmtree(tempdir)\n",
    "        feature_store.close()\n",
    "\n",
    "finally:\n",
    "    pool.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract Validation Set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "split_size = 5000\n",
    "pool = Pool()\n",
    "\n",
    "try:\n",
    "    if not os.path.exists('/tf/open_images/validation/features/'):\n",
    "        os.mkdir('/tf/open_images/validation/features/')\n",
    "\n",
    "    archive_path = '/tf/open_images/validation/images/validation.tar.gz'\n",
    "    ar_name = 'validation'\n",
    "    feat_path = '/tf/open_images/validation/features/validation.h5'\n",
    "    \n",
    "    if not os.path.exists(feat_path):\n",
    "\n",
    "        feature_store = pd.HDFStore(feat_path, complevel=9, mode='w')\n",
    "\n",
    "        tempdir = tempfile.mkdtemp()\n",
    "        archive = tarfile.open(archive_path)\n",
    "        archive.extractall(tempdir)\n",
    "\n",
    "        image_paths = sorted(glob.glob(os.path.join(tempdir, ar_name, \"*.jpg\")))\n",
    "        splits = np.array_split(np.array(image_paths), len(image_paths)/split_size)\n",
    "\n",
    "        for split in tqdm(splits, desc=ar_name):\n",
    "            image_ids = [ os.path.basename(path).replace('.jpg', '') for path in split ]\n",
    "\n",
    "            image_array = np.concatenate(pool.map(preprocess, split))\n",
    "\n",
    "            features = extractor.predict(image_array)\n",
    "            del image_array\n",
    "\n",
    "            frame = pd.DataFrame(features, index=image_ids)\n",
    "            del features\n",
    "\n",
    "            feature_store.append('df', frame)\n",
    "\n",
    "        shutil.rmtree(tempdir)\n",
    "        feature_store.close()\n",
    "\n",
    "finally:\n",
    "    pool.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract Training Set\n",
    "This code will process the images from the encoder training set. If you have not done so already, open and run the cells in `01_download.ipynb` that download the training set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists('/tf/open_images/train/features'):\n",
    "    os.mkdir('/tf/open_images/train/features')\n",
    "\n",
    "prefixes = [ \"%x\"%i for i in range(16) ]\n",
    "\n",
    "pool = Pool()\n",
    "try:\n",
    "    for prefix in tqdm(prefixes, desc='all files'):\n",
    "\n",
    "        feat_path = '/tf/open_images/train/features/train_%s.h5' % prefix\n",
    "\n",
    "        if os.path.exists(feat_path):\n",
    "            continue\n",
    "\n",
    "        feature_store = pd.HDFStore(feat_path, complevel=9, mode='w')\n",
    "\n",
    "        image_dirs = sorted(glob.glob('/tf/open_images/train/images/%s*' % prefix ))\n",
    "\n",
    "        for image_dir in tqdm(image_dirs, desc=\"prefix %s\" % prefix):\n",
    "\n",
    "            image_paths = sorted(glob.glob(os.path.join(image_dir, '*.jpg')))\n",
    "\n",
    "            image_ids = [ os.path.basename(path).replace('.jpg', '') for path in image_paths ]\n",
    "\n",
    "            image_array = np.concatenate(pool.map(preprocess, image_paths))\n",
    "\n",
    "            features = extractor.predict(image_array)\n",
    "            del image_array\n",
    "\n",
    "            frame = pd.DataFrame(features, index=image_ids)\n",
    "            del features\n",
    "\n",
    "            feature_store.append('df', frame)\n",
    "\n",
    "        feature_store.close()\n",
    "        \n",
    "finally:\n",
    "    pool.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract Extended Target Set\n",
    "This code will process the images from the extended target set. If you have not done so already, open and run the cells in `01_download.ipynb` that download the extended target set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.exists('/tf/open_images/extended_targets/features'):\n",
    "    os.mkdir('/tf/open_images/extended_targets/features')\n",
    "\n",
    "prefixes = [ \"%x\"%i for i in range(16) ]\n",
    "\n",
    "pool = Pool()\n",
    "try:\n",
    "    for prefix in tqdm(prefixes, desc='all files'):\n",
    "\n",
    "        feat_path = '/tf/open_images/extended_targets/features/extended_targets_%s.h5' % prefix\n",
    "\n",
    "        if os.path.exists(feat_path):\n",
    "            continue\n",
    "\n",
    "        feature_store = pd.HDFStore(feat_path, complevel=9, mode='w')\n",
    "\n",
    "        image_dirs = sorted(glob.glob('/tf/open_images/extended_targets/images/%s*' % prefix ))\n",
    "\n",
    "        for image_dir in tqdm(image_dirs, desc=\"prefix %s\" % prefix):\n",
    "\n",
    "            image_paths = sorted(glob.glob(os.path.join(image_dir, '*.jpg')))\n",
    "\n",
    "            image_ids = [ os.path.basename(path).replace('.jpg', '') for path in image_paths ]\n",
    "\n",
    "            image_array = np.concatenate(pool.map(preprocess, image_paths))\n",
    "\n",
    "            features = extractor.predict(image_array)\n",
    "            del image_array\n",
    "\n",
    "            frame = pd.DataFrame(features, index=image_ids)\n",
    "            del features\n",
    "\n",
    "            feature_store.append('df', frame)\n",
    "\n",
    "        feature_store.close()\n",
    "        \n",
    "finally:\n",
    "    pool.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract Query Set\n",
    "This code will process the images from the query set. You may use the included query images, or add your own `.jpg` images to the `data/queries/images` folder."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "query_image_paths = glob.glob('/tf/primo/data/queries/images/*.jpg')\n",
    "query_ids = [os.path.basename(path).replace('.jpg','') for path in query_image_paths]\n",
    "\n",
    "query_image_array = np.concatenate([preprocess(image_path) for image_path in query_image_paths])\n",
    "query_features = extractor.predict(query_image_array)\n",
    "\n",
    "pd.DataFrame(query_features, index=query_ids).to_hdf('/tf/primo/data/queries/features.h5', key='df', mode='w')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.15+"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
